# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import math
import os
import sys
import random
import shutil
import tempfile
import time

import numpy as np
import six

from tensorflow.python.ops.hash_table.embedding import EmbeddingLookupHook
from tensorflow.python.framework import ops
from tensorflow.python.training.training_util import get_or_create_global_step
from tensorflow.python.ops.init_ops import *
from tensorflow.python.ops import gen_hash_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import function
from tensorflow.python.framework import dtypes
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.ops import logging_ops
from tensorflow.python.training.monitored_session import _RecoverableSession
from tensorflow.python.client import session
from tensorflow.python.util import tf_contextlib

# Index-aligned registries used by the test-only post-check machinery.
# HashFilterBase.on_embedding_lookup adds one entry per hash-table partition
# to each list, and check() walks them by the same index:
# fetch list (check ops) for each post-check,
_FILTER_CHECK_OPS = []
# function applied to the fetched values of the same index,
_FILTER_CHECK_FNS = []
# hash table whose `.name` is reported for that index.
_FILTER_CHECK_HASHTABLES = []

def _get_internal_sess(sess):
  """Unwraps nested session wrappers down to a raw `session.Session`.

  Repeatedly follows the private `_sess` attribute of wrapper objects. A
  `_RecoverableSession` whose inner session has not been created yet is asked
  to create it first.

  Args:
    sess: a `session.Session` or a wrapper object exposing `_sess`.

  Returns:
    The first `session.Session` reached while unwrapping. If unwrapping fails
    at some level (e.g. an object without `_sess`, or inner-session creation
    raising), returns the object reached at that point.
  """
  while True:
    try:
      if isinstance(sess, session.Session):
        return sess
      if isinstance(sess, _RecoverableSession):
        if not sess._sess:
          sess._create_session()
      sess = sess._sess
    except Exception:  # pylint: disable=broad-except
      # Deliberate best effort: stop at the first object that cannot be
      # unwrapped further. Catching `Exception` instead of a bare `except:`
      # keeps KeyboardInterrupt / SystemExit propagating.
      return sess

class HashFilterBase(EmbeddingLookupHook):
  """Base class for embedding-lookup hooks that evict entries from hash tables.

  For every partition of an embedding lookup this hook builds:
    * update ops (from `update`) that record per-lookup information,
    * one filter op per hash table (from `filter`) that removes the ids the
      subclass selects,
    * test-only check ops plus a check function (from `check`).
  The grouped update/filter ops are added to graph collections; the check ops,
  check functions and hash tables are appended to module-level lists.

  Subclasses read the partition currently being processed from
  `self._hash_table`.
  """
  UPDATE_OPS_COLLECTION_KEY = '_hash_filter_update_ops'
  FILTER_OPS_COLLECTION_KEY = '_hash_filter_filter_ops'
  CHECK_OPS_COLLECTION_KEY = '_hash_filter_check_ops'

  def __init__(self, block_size=1024*1024, parallel_num=4):
    """Creates the filter.

    Args:
      block_size: passed as `block_size` to `hash_table_filter_op`.
      parallel_num: accepted for backward compatibility; ignored.
    """
    del parallel_num
    self._block_size = block_size
    # Partition currently being processed; set by on_embedding_lookup.
    self._hash_table = None
    # Hash tables that already have a filter op, so each gets exactly one.
    self._filter_set = set()

  def get_config(self):
    """Returns this filter's configuration as a dict."""
    return {
      'block_size': self._block_size,
      }

  @staticmethod
  def _as_op_list(op_or_ops):
    """Wraps a single op in a list; lists and tuples are returned unchanged."""
    if isinstance(op_or_ops, (list, tuple)):
      return op_or_ops
    return [op_or_ops]

  def on_embedding_lookup(self, ctx):
    """Builds update, filter and check ops for every partition of a lookup.

    All update ops are grouped into one op added to
    UPDATE_OPS_COLLECTION_KEY, and all newly built filter ops into one op added
    to FILTER_OPS_COLLECTION_KEY. One (check ops, check fn, hash table) entry
    per partition is appended to the module-level check registries.

    Args:
      ctx: lookup context exposing `partitions()`, `partitioned_keys()` and
        `ids()`, each indexable by partition number.
    """
    partitions = ctx.partitions()
    partitioned_keys = ctx.partitioned_keys()
    partition_ids = ctx.ids()

    update_ops_all = []
    filter_ops_all = []
    check_ops_all = []
    check_fns_all = []
    check_tables_all = []
    for i, ht in enumerate(partitions):
      # update/filter/check read the current partition from self._hash_table.
      self._hash_table = ht
      update_ops_all.extend(
          self._as_op_list(self._update(partitioned_keys[i], partition_ids[i])))
      # A hash table can show up in several lookups; filter it only once.
      if ht not in self._filter_set:
        filter_ops_all.extend(self._as_op_list(self._filter(ht.device, ht)))
        self._filter_set.add(ht)
      check_ops, check_fn = self._check()
      check_ops_all.append(check_ops)
      check_fns_all.append(check_fn)
      check_tables_all.append(self._hash_table)

    grouped_update = control_flow_ops.group(
        update_ops_all, name="hash_filter_update_op")
    grouped_filter = control_flow_ops.group(
        filter_ops_all, name="hash_filter_filter_op")
    ops.add_to_collection(HashFilterBase.UPDATE_OPS_COLLECTION_KEY, grouped_update)
    ops.add_to_collection(HashFilterBase.FILTER_OPS_COLLECTION_KEY, grouped_filter)
    # The three registries are read by the same index, so extend them together;
    # a failure earlier in this method then leaves none of them partially updated.
    _FILTER_CHECK_OPS.extend(check_ops_all)
    _FILTER_CHECK_FNS.extend(check_fns_all)
    _FILTER_CHECK_HASHTABLES.extend(check_tables_all)

  def _update(self, keys, ids):
    """Delegates to `update` for the current partition."""
    return self.update(keys, ids)

  def _filter(self, device, hash_table):
    """Builds the filter op for one hash table.

    `self.filter` is wrapped in a `function.Defun` taking int64 (keys, ids)
    and passed, together with its captured inputs, to `hash_table_filter_op`.

    Args:
      device: the hash table's device. NOTE(review): unused here; the op is
        not explicitly placed on it. Confirm this is intended.
      hash_table: the partition being filtered.

    Returns:
      The filter op returned by `hash_table_filter_op`.
    """
    @function.Defun(dtypes.int64, dtypes.int64)
    def _filter_func_wrapper(keys, ids):
      return self.filter(keys, ids)
    return gen_hash_ops.hash_table_filter_op(
      hash_table.hash_table.handle,
      _filter_func_wrapper.captured_inputs,
      f=_filter_func_wrapper,
      block_size=self._block_size)

  def _check(self):
    """Calls `check` with the current partition's `snapshot` (keys, ids)."""
    keys, ids = self._hash_table.snapshot
    return self.check(keys, ids)

  def update(self, keys, ids):
    """Records information needed in the filter stage.

    Args:
      keys: embedding_lookup keys.
      ids: embedding_lookup ids generated by the hash_table lookup.

    Returns:
      An op or a list/tuple of ops to run on every step.

    Raises:
      ValueError: always; subclasses must override.
    """
    raise ValueError("never reach here")

  def filter(self, keys, ids):
    """Filter logic.

    Args:
      keys: keys in the hash_table.
      ids: ids in the hash_table.

    Returns:
      A boolean tensor indicating which ids should be filtered out.

    Raises:
      ValueError: always; subclasses must override.
    """
    raise ValueError("never reach here")

  def check(self, keys, ids):
    """Test-only post-condition for the filter.

    Args:
      keys: keys in the hash_table.
      ids: ids in the hash_table.

    Returns:
      A (check ops, check_fn) pair. check_fn receives the fetched values of
      the check ops as positional arguments and returns a truthy value when
      the check passes. The default has no ops and always passes.
    """
    return [], lambda: True

def update_op():
  """Returns the ops stored under HashFilterBase.UPDATE_OPS_COLLECTION_KEY."""
  collection_key = HashFilterBase.UPDATE_OPS_COLLECTION_KEY
  return ops.get_collection(collection_key)

def filter_op():
  """Returns the ops stored under HashFilterBase.FILTER_OPS_COLLECTION_KEY."""
  collection_key = HashFilterBase.FILTER_OPS_COLLECTION_KEY
  return ops.get_collection(collection_key)

def check_op():
  """Returns the module-level list of registered check-op lists (not a copy)."""
  registered_check_ops = _FILTER_CHECK_OPS
  return registered_check_ops

class GlobalStepFilter(HashFilterBase):
  """Evicts ids that have not been looked up for `filter_interval_steps` steps.

  Every lookup writes the current global step into an int64 `update_step`
  slot of each looked-up id. The filter evicts ids whose recorded step is too
  far behind the current global step.
  """

  def __init__(self, filter_interval_steps):
    super(GlobalStepFilter, self).__init__()
    # Raw value kept for get_config() and check(); the tensor is used in graphs.
    self._filter_interval_steps = filter_interval_steps
    self._filter_interval_steps_tensor = ops.convert_to_tensor(
        filter_interval_steps, dtype=dtypes.int64)
    self._global_step = get_or_create_global_step()

  def get_config(self):
    """Returns the base config plus `filter_interval_steps`."""
    attr_dict = super(GlobalStepFilter, self).get_config()
    attr_dict.update({
      'filter_interval_steps': self._filter_interval_steps
      })
    return attr_dict

  def _gather_update_step(self, ids):
    """Gathers the `update_step` slot values of `ids` (gather default 0)."""
    step_slot = self._hash_table.get_slot('update_step')
    default_value = ops.convert_to_tensor(0, dtype=dtypes.int64)
    return gen_hash_ops.tensible_variable_gather(
        step_slot.handle, ids, default_value)

  def update(self, keys, ids):
    """Writes the current global step into the `update_step` slot of `ids`."""
    step_slot = self._hash_table.get_or_create_slot(
        [1], dtypes.int64, 'update_step', initializer=Zeros(dtypes.int64))
    num_ids = array_ops.shape(ids)[0]
    # One row per id holding the global step, reshaped to [num_ids, 1].
    update_value = array_ops.reshape(
        array_ops.tile([self._global_step], [num_ids]), [-1, 1])
    return gen_hash_ops.tensible_variable_scatter_update(
      step_slot.handle,
      ids,
      update_value)

  def filter(self, keys, ids):
    """Returns a flat bool mask, True where
    `global_step - 1 - update_step >= filter_interval_steps`."""
    slot_value = self._gather_update_step(ids)
    idle_steps = self._global_step - 1 - slot_value
    filter_mask = math_ops.greater_equal(
        idle_steps, self._filter_interval_steps_tensor)
    return array_ops.reshape(filter_mask, [-1])

  def check(self, keys, ids):
    """Test-only: every remaining id must satisfy
    `update_step >= global_step - filter_interval_steps`."""
    slot_value = self._gather_update_step(ids)
    return ([slot_value, self._global_step],
            lambda x, y: (x >= y - self._filter_interval_steps).all())

class L2WeightFilter(HashFilterBase):
  """Evicts ids whose embedding row has a small L2 penalty.

  Each id's score is `reduce_sum(0.5 * weight ** 2, axis=1)` over the row
  gathered from the hash table (gather default value 0). Ids scoring below
  `threshold` are evicted.
  """

  def __init__(self, threshold):
    super(L2WeightFilter, self).__init__()
    # Keep the caller's value for get_config(); the float32 tensor below only
    # lives in the graph and is not a meaningful config value.
    self._threshold_value = threshold
    self._threshold = ops.convert_to_tensor(threshold, dtypes.float32)

  def get_config(self):
    """Returns the base config plus the configured `threshold` value."""
    attr_dict = super(L2WeightFilter, self).get_config()
    attr_dict.update({
      'threshold': self._threshold_value
      })
    return attr_dict

  def _l2_weight(self, ids):
    """Returns `0.5 * sum(weight ** 2)` along axis 1 for the rows of `ids`."""
    # NOTE(review): assumes the gathered weights are rank 2 ([num_ids, dim]).
    default_value = ops.convert_to_tensor(0, dtype=dtypes.float32)
    weight = gen_hash_ops.tensible_variable_gather(
        self._hash_table.handle, ids, default_value)
    return math_ops.reduce_sum(math_ops.square(weight) * 0.5, axis=1)

  def update(self, keys, ids):
    """No per-lookup state is recorded; the filter reads weights directly."""
    return []

  def filter(self, keys, ids):
    """Returns a flat bool mask, True where the L2 score is below threshold."""
    filter_mask = math_ops.less(self._l2_weight(ids), self._threshold)
    return array_ops.reshape(filter_mask, [-1])

  def check(self, keys, ids):
    """Test-only: every remaining id must score at least `threshold`."""
    # Use the snapshot ids handed in by `_check` rather than taking another one.
    return [self._l2_weight(ids), self._threshold], lambda x, y: (x >= y).all()

class HashFilterHook(session_run_hook.SessionRunHook):
  """Runs hash-filter update ops every step and filter ops on the chief.

  Update ops are fetched on every worker. Filter ops run only when `is_chief`:
  periodically (every `interval` global steps) and/or once when the session
  ends (`run_at_session_close`).
  """

  def __init__(self, is_chief, interval=None, run_at_session_close=None):
    """Creates the hook.

    Args:
      is_chief: whether this worker may run the filter ops.
      interval: run filter ops once the fetched global step has advanced by at
        least this many steps since the last run; falsy disables this.
      run_at_session_close: if truthy, run filter ops once in `end`.
    """
    self._is_chief = is_chief
    self._interval = interval
    self._run_at_session_close = run_at_session_close
    self._last_trigger_step = 0
    self._filter_ops = []
    self._update_ops = []
    self._global_step = get_or_create_global_step()

  def after_create_session(self, sess, coord):
    """Collects the registered filter and update ops from the graph."""
    self._filter_ops = filter_op()
    self._update_ops = update_op()

  def before_run(self, run_context):
    """Fetches the update ops, prefixed by the global step when periodic."""
    fetches = self._update_ops
    if self._interval:
      fetches = [self._global_step] + fetches
    return SessionRunArgs(fetches)

  def after_run(self, run_context, run_values):
    """On the chief, runs the filter ops once `interval` steps have passed."""
    if not (self._is_chief and self._interval):
      return
    current_step = run_values.results[0]
    if current_step - self._last_trigger_step < self._interval:
      return
    _get_internal_sess(run_context.session).run(self._filter_ops)
    self._last_trigger_step = current_step

  def end(self, sess):
    """On the chief, optionally runs the filter ops one last time."""
    should_filter = self._is_chief and self._run_at_session_close
    if should_filter:
      _get_internal_sess(sess).run(self._filter_ops)

def filter_once(sess):
  """Runs all registered hash-filter ops once on the unwrapped `sess`."""
  raw_sess = _get_internal_sess(sess)
  raw_sess.run(filter_op())

def check(sess, verbose=True):
  """Runs the test-only hash-filter post-checks and prints the results.

  All registered check ops are fetched with the unwrapped `sess`. Each fetched
  value list is passed to the check function of the same index. For failing
  checks the values are printed in full (numpy summarization disabled only
  while printing). With `verbose`, values of passing checks are printed too.

  Args:
    sess: a session or session wrapper; unwrapped via `_get_internal_sess`.
    verbose: also print the values of checks that pass.

  Returns:
    True if every check passed, False otherwise.
  """
  values = _get_internal_sess(sess).run(check_op())
  all_passed = True
  for i, check_values in enumerate(values):
    table_name = _FILTER_CHECK_HASHTABLES[i].name
    if not _FILTER_CHECK_FNS[i](*check_values):
      print('====== HashTable[%s] HashFilter PostCheck Failed =====' % table_name)
      # np.set_printoptions() with no arguments leaves every option unchanged,
      # so save the current options and restore them explicitly.
      saved_options = np.get_printoptions()
      np.set_printoptions(threshold=sys.maxsize)
      try:
        print(check_values)
      finally:
        np.set_printoptions(**saved_options)
      all_passed = False
    elif verbose:
      print('===== HashTable[%s] HashFilter PostCheck Details ======' % table_name)
      print(check_values)
  if all_passed:
    print('======= HashFilter post check passed ========')
  return all_passed



class BlackListAdmit(EmbeddingLookupHook):
  """Hands out one black-list admit-strategy op per hash table.

  Strategies must be requested inside a `token(name)` context. Each hash table
  is tagged with the token name active when its strategy was requested.
  """

  def __init__(self):
    # hash table -> admit-strategy op created for it (on the table's device).
    self.admit = {}
    # Name set by the innermost active `token()` context, or None.
    self.token_name = None
    # hash table -> token name active when its strategy was last requested.
    self.tokens = {}

  def get_config(self):
    """Returns an empty config; this hook has no constructor arguments."""
    return {}

  @tf_contextlib.contextmanager
  def token(self, name):
    """Makes `name` the active token name; restores the previous one on exit."""
    # Read the previous name before `try` so `finally` never sees it unbound.
    latest_name = self.token_name
    self.token_name = name
    try:
      yield
    finally:
      self.token_name = latest_name

  def get_admit_strategy(self, ht):
    """Returns the admit-strategy op for `ht`, creating it on first use.

    Args:
      ht: hash table; its `device` is used to place a newly created op.

    Returns:
      The admit-strategy op cached for `ht`.

    Raises:
      ValueError: if called outside a `token()` context.
    """
    # Validate before creating/caching the op. Otherwise a failed call would
    # leave an `admit` entry with no matching `tokens` entry, and `init` would
    # then raise KeyError.
    if self.token_name is None:
      raise ValueError("No token name")
    if ht not in self.admit:
      with ops.device(ht.device):
        self.admit[ht] = gen_hash_ops.black_list_hash_table_admit_strategy_op()
    self.tokens[ht] = self.token_name
    return self.admit[ht]

  def init(self, filename):
    """Groups (hash table, token name, admit op) triples by device and prints them.

    Args:
      filename: currently unused.
    """
    # NOTE(review): `filename` is ignored and nothing is initialized here;
    # this looks like an unfinished stub. Confirm the intended behavior.
    dvs = {}
    for ht, admit_op in self.admit.items():
      dvs.setdefault(ht.device, []).append((ht, self.tokens[ht], admit_op))
    print(dvs)
